{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "raised-cradle",
   "metadata": {},
   "source": [
    "# 分子生成"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "gothic-convergence",
   "metadata": {},
   "source": [
    "在这篇教程中，我们将会介绍如何训练一个基于序列的VAE模型去生成分子的SMILES序列。我们将会介绍模型的训练和通过训练好的模型进行采样生成。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "incoming-watch",
   "metadata": {},
   "source": [
    "## 基于序列的VAE"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "informative-departure",
   "metadata": {},
   "source": [
    "分子生成是一种通过在大数据集上训练深层生成模型来生成新分子的流行工具。生成模型可用于设计新分子、探索分子空间等。其生成的分子可进一步用于虚拟筛选或其他下游任务。\n",
    "在这项工作中，我们将介绍一个变分自动编码器（VAE）为基础的生成模型。\n",
    "\n",
    "VAE包含两个神经网络 - 一个编码器和一个解码器。利用这种结构，该模型可以通过编码器将高维输入空间转换为低维的隐空间，并通过解码器将其转换回原始的输入空间。 隐空间是一个正态分布的连续向量空间。我们最小化了Kullback-Leibler（KL）散度损失和重构损失。利用隐空间连续的性质，我们可以利用训练好的VAE模型对新分子进行采样。\n",
    "\n",
    "分子的输入是SMILES序列。通过两者的结合，序列VAE模型将一个SMILES序列作为输入，重构输入序列。"
   ]
  },
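  {
   "cell_type": "markdown",
   "id": "sketch-vae-objective",
   "metadata": {},
   "source": [
    "The objective described above can be written out explicitly. The following is a sketch under the common assumption that the encoder outputs a diagonal Gaussian $q(z \\mid x) = \\mathcal{N}(\\mu, \\mathrm{diag}(\\sigma^2))$ and that the prior is $\\mathcal{N}(0, I)$. The symbols $\\mu$, $\\sigma$, $\\beta$ and $d$ are introduced here for illustration; we have not checked them against the PaddleHelix implementation.\n",
    "\n",
    "$$\\mathcal{L}(x) = -\\,\\mathbb{E}_{q(z \\mid x)}\\left[\\log p(x \\mid z)\\right] + \\beta\\,\\mathrm{KL}\\left(q(z \\mid x) \\,\\Vert\\, \\mathcal{N}(0, I)\\right)$$\n",
    "\n",
    "$$\\mathrm{KL}\\left(\\mathcal{N}(\\mu, \\mathrm{diag}(\\sigma^2)) \\,\\Vert\\, \\mathcal{N}(0, I)\\right) = \\frac{1}{2} \\sum_{j=1}^{d} \\left(\\mu_j^2 + \\sigma_j^2 - \\log \\sigma_j^2 - 1\\right)$$\n",
    "\n",
    "The first term of $\\mathcal{L}$ is the reconstruction loss. The weight $\\beta$ plays the role of `kl_weight` in the training loop, and $d$ is the latent dimension (`d_z` in the model configuration)."
   ]
  },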
  {
   "cell_type": "markdown",
   "id": "middle-bulletin",
   "metadata": {},
   "source": [
    "![title](./figures/seq_VAE.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "olympic-stereo",
   "metadata": {},
   "source": [
    "## 部分 I: 训练一个序列VAE模型"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "piano-alloy",
   "metadata": {},
   "source": [
    "### 加载数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "worth-deputy",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import os\n",
    "seq_VAE_path = '../apps/molecular_generation/seq_VAE/'\n",
    "sys.path.insert(0, os.getcwd() + \"/..\")\n",
    "sys.path.append(seq_VAE_path)\n",
    "from utils import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "ceramic-camcorder",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--2021-05-13 14:38:50--  https://baidu-nlp.bj.bcebos.com/PaddleHelix/datasets/molecular_generation/zinc_moses.tgz\n",
      "Resolving baidu-nlp.bj.bcebos.com (baidu-nlp.bj.bcebos.com)... 10.70.0.165\n",
      "Connecting to baidu-nlp.bj.bcebos.com (baidu-nlp.bj.bcebos.com)|10.70.0.165|:443... connected.\n",
      "HTTP request sent, awaiting response... "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "200 OK\n",
      "Length: 8708409 (8.3M) [application/gzip]\n",
      "Saving to: ‘zinc_moses.tgz.1’\n",
      "\n",
      "zinc_moses.tgz.1    100%[===================>]   8.30M  2.54MB/s    in 3.5s    \n",
      "\n",
      "2021-05-13 14:38:54 (2.34 MB/s) - ‘zinc_moses.tgz.1’ saved [8708409/8708409]\n",
      "\n",
      "x zinc_moses/\n",
      "x zinc_moses/.DS_Store\n",
      "x zinc_moses/test.csv\n",
      "x zinc_moses/train.csv\n"
     ]
    }
   ],
   "source": [
    "# download and decompress the data\n",
    "!wget https://baidu-nlp.bj.bcebos.com/PaddleHelix/datasets/molecular_generation/zinc_moses.tgz\n",
    "!tar -zxvf \"zinc_moses.tgz\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "combined-preliminary",
   "metadata": {},
   "outputs": [],
   "source": [
    "data_path = './zinc_moses/train.csv'\n",
    "train_data = load_zinc_dataset(data_path)\n",
    "# get the toy data\n",
    "train_data = train_data[0:1000]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "hindu-acquisition",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1000"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(train_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "rapid-jungle",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['CCCS(=O)c1ccc2[nH]c(=NC(=O)OC)[nH]c2c1',\n",
       " 'CC(C)(C)C(=O)C(Oc1ccc(Cl)cc1)n1ccnc1',\n",
       " 'Cc1c(Cl)cccc1Nc1ncccc1C(=O)OCC(O)CO',\n",
       " 'Cn1cnc2c1c(=O)n(CC(O)CO)c(=O)n2C',\n",
       " 'CC1Oc2ccc(Cl)cc2N(CC(O)CO)C1=O',\n",
       " 'CCOC(=O)c1cncn1C1CCCc2ccccc21',\n",
       " 'COc1ccccc1OC(=O)Oc1ccccc1OC',\n",
       " 'O=C1Nc2ccc(Cl)cc2C(c2ccccc2Cl)=NC1O',\n",
       " 'CN1C(=O)C(O)N=C(c2ccccc2Cl)c2cc(Cl)ccc21',\n",
       " 'CCC(=O)c1ccc(OCC(O)CO)c(OC)c1']"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_data[0:10]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "elder-drive",
   "metadata": {},
   "source": [
    "### 定义语法"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "established-arena",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 基于训练集定义序列的语法\n",
    "vocab = OneHotVocab.from_data(train_data)"
   ]
  },
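  {
   "cell_type": "markdown",
   "id": "sketch-char-vocab-note",
   "metadata": {},
   "source": [
    "To make the idea of a vocabulary concrete, the next cell is a minimal sketch of a character-level vocabulary that maps SMILES strings to fixed-length integer sequences. It is an illustration only: the helper names (`toy_build_vocab`, `toy_encode`, `toy_decode`) and the special tokens are hypothetical, and `OneHotVocab` may tokenize and pad differently. Note that a purely character-level split breaks two-letter atoms such as `Cl` into `C` and `l`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sketch-char-vocab-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical character-level vocabulary; NOT the OneHotVocab implementation.\n",
    "TOY_SPECIAL_TOKENS = ['<pad>', '<bos>', '<eos>', '<unk>']\n",
    "\n",
    "\n",
    "def toy_build_vocab(smiles_list):\n",
    "    # special tokens get the lowest ids, followed by every character seen in the data\n",
    "    chars = sorted({ch for smi in smiles_list for ch in smi})\n",
    "    return {tok: i for i, tok in enumerate(TOY_SPECIAL_TOKENS + chars)}\n",
    "\n",
    "\n",
    "def toy_encode(smiles, char2id, max_length):\n",
    "    # wrap the string in <bos>/<eos>, truncate if needed, then pad to max_length\n",
    "    unk = char2id['<unk>']\n",
    "    ids = [char2id['<bos>']] + [char2id.get(ch, unk) for ch in smiles]\n",
    "    ids = ids[:max_length - 1] + [char2id['<eos>']]\n",
    "    return ids + [char2id['<pad>']] * (max_length - len(ids))\n",
    "\n",
    "\n",
    "def toy_decode(ids, char2id):\n",
    "    # invert the mapping and drop the special tokens\n",
    "    id2char = {i: tok for tok, i in char2id.items()}\n",
    "    return ''.join(id2char[i] for i in ids if id2char[i] not in TOY_SPECIAL_TOKENS)\n",
    "\n",
    "\n",
    "toy_vocab = toy_build_vocab(['CCO', 'c1ccccc1'])\n",
    "toy_ids = toy_encode('CCO', toy_vocab, max_length=8)\n",
    "print(toy_ids)\n",
    "print(toy_decode(toy_ids, toy_vocab))"
   ]
  },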
  {
   "cell_type": "markdown",
   "id": "martial-attribute",
   "metadata": {},
   "source": [
    "### 模型参数设置"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "periodic-confusion",
   "metadata": {},
   "source": [
    "神经网络的参数存储在model_config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 110,
   "id": "organic-bronze",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_config = \\\n",
    "{\n",
    "    \"max_length\":80,     # 序列的最大长度\n",
    "    \"q_cell\": \"gru\",     # 编码器cell的类型\n",
    "    \"q_bidir\": 1,        # 是否编码器是双向RNN\n",
    "    \"q_d_h\": 256,        # 编码器隐空间大小\n",
    "    \"q_n_layers\": 1,     # 编码器RNN层数\n",
    "    \"q_dropout\": 0.5,    # 编码器dropout rate\n",
    "\n",
    "\n",
    "    \"d_cell\": \"gru\",     # 解码器cell类型\n",
    "    \"d_n_layers\":3,      # 解码器RNN层数\n",
    "    \"d_dropout\":0.2,     # 解码器dropout rate\n",
    "    \"d_z\":128,           # VAE隐空间大小\n",
    "    \"d_d_h\":512,         # 解码器隐空间大小\n",
    "    \"freeze_embeddings\":0 # 是否固定one-hot embedding\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "computational-lecture",
   "metadata": {},
   "source": [
    "### 定义模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 111,
   "id": "competent-valuable",
   "metadata": {},
   "outputs": [],
   "source": [
    "# build the model\n",
    "from pahelix.model_zoo.seq_vae_model  import VAE\n",
    "model = VAE(vocab, model_config)  "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "weekly-ebony",
   "metadata": {},
   "source": [
    "### 训练模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 112,
   "id": "subject-clinton",
   "metadata": {},
   "outputs": [],
   "source": [
    "# define the training settings\n",
    "batch_size = 64\n",
    "learning_rate = 0.001\n",
    "n_epoch = 1\n",
    "kl_weight = 0.1\n",
    "\n",
    "# define optimizer\n",
    "optimizer = paddle.optimizer.Adam(parameters=model.parameters(),\n",
    "                            learning_rate=learning_rate)\n",
    "\n",
    "# build the dataset and data loader\n",
    "max_length = model_config[\"max_length\"]\n",
    "train_dataset = StringDataset(vocab, train_data, max_length)\n",
    "train_dataloader = paddle.io.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 113,
   "id": "objective-crash",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "#######################\n",
      "batch:0, kl_loss:0.377259, recon_loss:3.379486\n",
      "batch:1, kl_loss:0.259201, recon_loss:3.264177\n",
      "batch:2, kl_loss:0.210570, recon_loss:3.144137\n",
      "batch:3, kl_loss:0.205814, recon_loss:3.053869\n",
      "batch:4, kl_loss:0.204681, recon_loss:2.960207\n",
      "batch:5, kl_loss:0.205177, recon_loss:2.892930\n",
      "batch:6, kl_loss:0.203757, recon_loss:2.838837\n",
      "batch:7, kl_loss:0.201053, recon_loss:2.782497\n",
      "batch:8, kl_loss:0.197671, recon_loss:2.751050\n",
      "batch:9, kl_loss:0.192766, recon_loss:2.715708\n",
      "batch:10, kl_loss:0.186594, recon_loss:2.684680\n",
      "batch:11, kl_loss:0.179440, recon_loss:2.664472\n",
      "batch:12, kl_loss:0.171974, recon_loss:2.641148\n",
      "batch:13, kl_loss:0.164508, recon_loss:2.620756\n",
      "batch:14, kl_loss:0.157552, recon_loss:2.605232\n",
      "batch:15, kl_loss:0.151044, recon_loss:2.586791\n",
      "epoch:0 loss:2.601895 kl_loss:0.151044 recon_loss:2.586791\n"
     ]
    }
   ],
   "source": [
    "# start to train \n",
    "for epoch in range(n_epoch):\n",
    "    print('#######################')\n",
    "    kl_loss_values = []\n",
    "    recon_loss_values = []\n",
    "    loss_values = []\n",
    "    \n",
    "    for batch_id, data in enumerate(train_dataloader()):\n",
    "        # read batch data\n",
    "        data_batch = data\n",
    "\n",
    "        # forward\n",
    "        kl_loss, recon_loss  = model(data_batch)\n",
    "        loss = kl_weight * kl_loss + recon_loss\n",
    "\n",
    "\n",
    "        # backward\n",
    "        loss.backward()\n",
    "        # optimize\n",
    "        optimizer.step()\n",
    "        # clear gradients\n",
    "        optimizer.clear_grad()\n",
    "        \n",
    "        # gathering values from each batch\n",
    "        kl_loss_values.append(kl_loss.numpy())\n",
    "        recon_loss_values.append(recon_loss.numpy())\n",
    "        loss_values.append(loss.numpy())\n",
    "\n",
    "        \n",
    "        print('batch:%s, kl_loss:%f, recon_loss:%f' % (batch_id, float(np.mean(kl_loss_values)), float(np.mean(recon_loss_values))))\n",
    "        \n",
    "    print('epoch:%d loss:%f kl_loss:%f recon_loss:%f' % (epoch, float(np.mean(loss_values)), float(np.mean(kl_loss_values)),float(np.mean(recon_loss_values))),flush=True)\n",
    "\n",
    "  "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "violent-crowd",
   "metadata": {},
   "source": [
    "## 部分 II: 从正态先验中采样"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 114,
   "id": "respective-teens",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'valid': 0.013000000000000012, 'unique@3': 0.6666666666666666, 'IntDiv': 0.7307692307692307, 'IntDiv2': 0.5181166128686162, 'Filters': 0.9230769230769231}\n"
     ]
    }
   ],
   "source": [
    "from pahelix.utils.metrics.molecular_generation.metrics_ import get_all_metrics\n",
    "N_samples = 1000  # number of samples \n",
    "max_len = 80      # maximum length of samples\n",
    "current_samples = model.sample(N_samples, max_len)  # get the samples from pre-trained model\n",
    "\n",
    "metrics = get_all_metrics(gen=current_samples, k=[3])  # get the evaluation from samples\n",
    "print(metrics)"
   ]
  },
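  {
   "cell_type": "markdown",
   "id": "sketch-unique-at-k-note",
   "metadata": {},
   "source": [
    "To give some intuition for the `unique@3` entry printed above, the next cell sketches a uniqueness score over the first `k` generated strings. This assumes the metric follows the MOSES-style definition (the fraction of distinct molecules among the first `k` valid ones), which we have not verified against the `get_all_metrics` source. The helper `toy_unique_at_k` is hypothetical, and it compares raw strings, whereas a real implementation would first canonicalize the SMILES and discard invalid ones."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sketch-unique-at-k-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical helper; NOT the implementation used by get_all_metrics.\n",
    "def toy_unique_at_k(smiles_list, k):\n",
    "    # fraction of distinct strings among the first k entries\n",
    "    head = smiles_list[:k]\n",
    "    if not head:\n",
    "        return 0.0\n",
    "    return len(set(head)) / len(head)\n",
    "\n",
    "\n",
    "# two of the first three strings are distinct\n",
    "toy_generated = ['CCO', 'CCO', 'c1ccccc1', 'CCN']\n",
    "print(toy_unique_at_k(toy_generated, k=3))"
   ]
  },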
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "divine-still",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
